# Exponential families
include("binary_search.jl");

# Gaussian family with known variance `σ2`; the mean is passed separately to the
# methods (`sample`, `d`, `dup`, `ddn`, ...) that dispatch on this type.
#
# The field type is a type parameter so it is concrete (e.g. `Gaussian{Float64}`)
# instead of `Any`, which lets the compiler infer every `expfam.σ2` access.
# The parameter is deliberately unconstrained to accept anything the untyped
# field previously accepted.
struct Gaussian{T}
    σ2::T
end

# Default Gaussian with unit variance (σ2 = 1), which simplifies computations.
Gaussian() = Gaussian(1)

# Draw one observation from the Gaussian with mean `μ` and variance `expfam.σ2`,
# using the random number generator `rng`: μ + σ·Z with Z ~ N(0, 1).
function sample(rng, expfam::Gaussian, μ)
    σ = sqrt(expfam.σ2)          # standard deviation
    return μ + σ * randn(rng)
end

# Kullback–Leibler divergence between the Gaussians with means `μ` and `λ` and
# common variance `expfam.σ2`:  (μ - λ)² / (2σ²).
function d(expfam::Gaussian, μ, λ)
    gap = μ - λ
    return gap^2 / (2 * expfam.σ2)
end

# Partial derivatives of the Gaussian KL divergence `d(expfam, μ, λ)`:
#   ∂d/∂μ = (μ - λ)/σ²   and   ∂d/∂λ = (λ - μ)/σ².
# The first name is spelled with the micro sign µ; Julia normalizes it to the
# Greek μ, so `dμ_d` and `dµ_d` refer to the same function.
function dµ_d(expfam::Gaussian, μ, λ)
    return (μ - λ) / expfam.σ2
end

function dλ_d(expfam::Gaussian, μ, λ)
    return (λ - μ) / expfam.σ2
end

# Upward (`dup`) and downward (`ddn`) confidence bounds (box confidence region)
# for the Gaussian family. Each is the solution λ of d(expfam, μ, λ) = v on its
# side of μ, i.e. μ ± sqrt(2σ²v).
function dup(expfam::Gaussian, μ, v)
    halfwidth = sqrt(2 * expfam.σ2 * v)
    return μ + halfwidth
end

function ddn(expfam::Gaussian, μ, v)
    halfwidth = sqrt(2 * expfam.σ2 * v)
    return μ - halfwidth
end


# Bernoulli family. A field-less singleton type that only selects the Bernoulli
# methods (`d`, `sample`, `dup`, `ddn`, ...) by dispatch; the mean μ is passed to
# those methods as a separate argument.
struct Bernoulli
end

# Elementwise relative entropy x·log(x/y) and its partial derivatives, using the
# convention 0·log(0/y) = 0.

# Zero of the floating-point type that `x * log(x / y)` produces. The `x == 0`
# branches return it so that both branches have the same type; a hard-coded `0.`
# would be Float64 even for Float32 inputs.
_rel_entr_zero(x, y) = zero(float(promote_type(typeof(x), typeof(y))))

rel_entr(x, y) = iszero(x) ? _rel_entr_zero(x, y) : x * log(x / y)

# NOTE: the true partial derivative in x is log(x/y) + 1. The constant +1 is
# intentionally omitted because it cancels in the Bernoulli `dµ_d`, which takes the
# difference of two such terms. At x = 0, where the true derivative diverges,
# 0 is returned by convention.
dx_rel_entr(x, y) = iszero(x) ? _rel_entr_zero(x, y) : log(x / y)

# Partial derivative of x·log(x/y) with respect to y.
dy_rel_entr(x, y) = -x / y

# KL divergence between Bernoulli(μ) and Bernoulli(λ):
#   μ·log(μ/λ) + (1 - μ)·log((1 - μ)/(1 - λ)).
# The divergence is nonnegative in exact arithmetic; the clamp at 0 presumably
# guards against small negative values from floating-point round-off.
function d(expfam::Bernoulli, μ, λ)
    divergence = rel_entr(μ, λ) + rel_entr(1 - μ, 1 - λ)
    return max(0, divergence)
end
# Partial derivatives of the Bernoulli divergence `d(expfam, μ, λ)` with respect
# to μ and λ. The second term enters with a minus sign because it depends on μ and
# λ through 1 - μ and 1 - λ (chain-rule factor -1).
function dµ_d(expfam::Bernoulli, μ, λ)
    μc, λc = 1 - μ, 1 - λ        # complementary probabilities
    return dx_rel_entr(μ, λ) - dx_rel_entr(μc, λc)
end

function dλ_d(expfam::Bernoulli, μ, λ)
    μc, λc = 1 - μ, 1 - λ        # complementary probabilities
    return dy_rel_entr(μ, λ) - dy_rel_entr(μc, λc)
end
# Root y = (s - (1 - x)) / (2x) of  x·y² + (1 - x)·y - μ = 0, with
# s = sqrt((1 - x)² + 4xμ). It is written in the rationalized form
#   2μ / ((1 - x) + s),
# which stays finite at x = 0 (giving y = μ) instead of dividing by x.
# NOTE(review): presumably the inverse of a map `h` defined elsewhere — confirm.
function invh(expfam::Bernoulli, μ, x)
    a = 1 - x
    s = sqrt((x - 1)^2 + 4 * x * μ)
    return 2 * μ / (a + s)
end
# Draw one Bernoulli(μ) observation (`true` with probability μ) using `rng`.
# `rand(rng)` is uniform on [0, 1), so the strict `<` gives P(true) = μ at the
# endpoints too: μ = 0 never yields `true` (with `≤`, a draw of exactly 0.0 would),
# and μ = 1 always does.
sample(rng, expfam::Bernoulli, μ) = rand(rng) < μ

# Upward confidence bound for the Bernoulli family: a root λ ∈ [μ, 1] of
# d(expfam, μ, λ) = v. The function λ ↦ d(μ, λ) - v is increasing on [μ, 1].
# `binary_search` comes from binary_search.jl; it is assumed to locate a root of
# an increasing function on the given interval — TODO confirm.
# When μ == 1 the interval is degenerate, so 1.0 is returned directly.
function dup(expfam::Bernoulli, μ, v)
    if μ == 1
        return 1.0
    end
    excess(λ) = d(expfam, μ, λ) - v
    return binary_search(excess, μ, 1)
end

# Downward confidence bound for the Bernoulli family: a root λ ∈ [0, μ] of
# d(expfam, μ, λ) = v. The divergence decreases in λ on [0, μ], so
# λ ↦ v - d(μ, λ) is increasing there. `binary_search` comes from
# binary_search.jl; it is assumed to locate a root of an increasing function on
# the given interval — TODO confirm.
# When μ == 0 the interval is degenerate, so 0.0 is returned directly.
function ddn(expfam::Bernoulli, μ, v)
    μ == 0 && return 0.0
    slack(λ) = v - d(expfam, μ, λ)
    return binary_search(slack, 0, μ)
end

